{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training an Autoencoder With TensorNodes\n",
    "\n",
    "When working with TensorNodes you can train them just like normal parts of a Nengo model. The only deviation from the standard procedure is the definition of the custom classes which encapsulate the TensorFlow portions. \n",
    "\n",
    "In this example we will illustrate this trainability by training an autoencoder on the MNIST dataset. The autoencoder takes the input in with a dimensionality of `784` (28\\*28) and reduces it to a dimensionality of `36`. This is the encoding phase where the network is effectively compressing the input. The decode phase then takes that `36` dimensional representation and attempts to reconstruct the original input with it.\n",
    "\n",
    "We show at the bottom of the notebook how the training changes the output. The output starts off as pure noise and gradually looks more and more like the input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import nengo\n",
    "import nengo_dl\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import MNIST data\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets(\"MNIST_data\", one_hot=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The network architecture consists of one input layer followed by 4 densely connected layers. The layers architecture is mirrored such that:\n",
    "\n",
    "```\n",
    "Layer 1: 784 Elements - Input\n",
    "Layer 2: 128 Elements - Encode 1\n",
    "Layer 3: 36 Elements - Encode 2\n",
    "Layer 4: 128 Elements - Decode 1\n",
    "Layer 5: 784 Elements - Decode 2/Output\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Network Parameters\n",
    "n_hidden_1 = 128 # 1st layer num features\n",
    "n_hidden_2 = 36 # 2nd layer num features\n",
    "n_input = 784 # MNIST data input (img shape: 28*28)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The TensorNodes are broken up into two classes: an Encoder, which compresses the input, and a Decoder, which decompresses the output to attempt to recreate the original.\n",
    "\n",
    "Both of the TensorNode types consist of two fully connected (dense) layers which are in turn connected to each other"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Building the encoder\n",
    "class Encoder(object):\n",
    "    def pre_build(self, shape_in, shape_out):\n",
    "        self.n_mini = shape_in[0]\n",
    "        self.size_in = shape_in[1]\n",
    "        self.size_out = shape_out[1]\n",
    "        \n",
    "    \n",
    "    def __call__(self, t, x):    \n",
    "        dense = tf.contrib.layers.fully_connected(x, n_hidden_1)\n",
    "        dense = tf.contrib.layers.fully_connected(dense, n_hidden_2)\n",
    "\n",
    "        return dense\n",
    "\n",
    "\n",
    "# Building the encoder\n",
    "class Decoder(object):\n",
    "    def pre_build(self, shape_in, shape_out):\n",
    "        self.n_mini = shape_in[0]\n",
    "        self.size_in = shape_in[1]\n",
    "        self.size_out = shape_out[1]\n",
    "    \n",
    "    def __call__(self, t, x):    \n",
    "        dense = tf.contrib.layers.fully_connected(x, n_hidden_1)\n",
    "        dense = tf.contrib.layers.fully_connected(dense, self.size_out)\n",
    "\n",
    "        return dense\n",
    "\n",
    "\n",
    "with nengo.Network() as net:\n",
    "    net.config[nengo.Connection].synapse = None\n",
    "    net.config[nengo.Ensemble].neuron_type = nengo.RectifiedLinear()\n",
    "    net.config[nengo.Ensemble].gain = nengo.dists.Choice([1])\n",
    "    net.config[nengo.Ensemble].bias = nengo.dists.Choice([0])\n",
    "    \n",
    "     # create node to feed in images\n",
    "    inp = nengo.Node(nengo.processes.PresentInput(mnist.test.images, 0.001))\n",
    "    \n",
    "    # create TensorNodes to insert into the network\n",
    "    tf_encode = nengo_dl.TensorNode(Encoder(), size_in=n_input, size_out=n_hidden_2, label='H1')\n",
    "    tf_decode = nengo_dl.TensorNode(Decoder(), size_in=n_hidden_2, size_out=n_input, label='H2')\n",
    "    \n",
    "    # connecting all the nodes together\n",
    "    nengo.Connection(inp, tf_encode)\n",
    "    nengo.Connection(tf_encode, tf_decode)\n",
    "    \n",
    "    # defining probes\n",
    "    input_probe = nengo.Probe(inp)\n",
    "    encode_probe = nengo.Probe(tf_encode)\n",
    "    decode_probe = nengo.Probe(tf_decode)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Luckily for this network the training data for output and input are the same, therefore the same data can be used for both."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_inputs = {inp: mnist.train.images[:, None, :]}\n",
    "train_targets = {decode_probe: train_inputs[inp]}\n",
    "test_inputs = {inp: mnist.test.images[:, None, :]}\n",
    "test_targets = {decode_probe: test_inputs[inp]}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We break the training up into single epochs so that we can visualize the effect of the training over time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n_examples = 100\n",
    "losses = []\n",
    "learning_rate = 0.05\n",
    "repetitions = 5\n",
    "with nengo_dl.Simulator(net, minibatch_size=n_examples, device=\"/cpu:0\") as sim:\n",
    "    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)\n",
    "    \n",
    "    for e in range(repetitions):\n",
    "        losses.append(sim.loss(test_inputs, test_targets, 'mse'))\n",
    "        print(losses)\n",
    "        \n",
    "        sim.step(input_feeds={inp: test_inputs[inp][:n_examples]})\n",
    "        \n",
    "        sim.train(train_inputs, train_targets, optimizer)\n",
    "        \n",
    "    sim.step(input_feeds={inp: test_inputs[inp][:n_examples]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plot below shows the input on the top row, the compressed representation in the middle, and then the reconstructed output on the bottom. We can see that as training progresses (moving left to right) the reconstructions become more and more accurate. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure, axis = plt.subplots(3, repetitions+1, figsize=(10, 5))\n",
    "for i in range(repetitions+1):\n",
    "    axis[0][i].imshow(np.reshape(sim.data[input_probe][0, i], (28, 28)))\n",
    "    axis[1][i].imshow(np.reshape(sim.data[encode_probe][0, i], (6, 6)))\n",
    "    axis[2][i].imshow(np.reshape(sim.data[decode_probe][0, i], (28, 28)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
